package top.lingkang.mm.utils;

import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.text.AntPathMatcher;
import cn.hutool.core.util.ClassUtil;
import lombok.extern.slf4j.Slf4j;
import top.lingkang.mm.error.MagicException;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.JarURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Static helper methods: classpath resource scanning, reflection shortcuts,
 * SQL script execution and JDBC metadata access.
 *
 * @Author lingkang
 * @Date 2024/2/29 16:41
 */
@Slf4j
public class MagicUtils {
    private static final AntPathMatcher antPathMatcher = new AntPathMatcher();
    /**
     * scanPath -> scan result. A concurrent map because this public static cache may be
     * read and populated by several threads at once.
     */
    private static final Map<String, List<String>> scanCache = new ConcurrentHashMap<>();

    /** URL prefix of resources located on the local file system. */
    private static final String FILE_URL_PREFIX = "file:";
    /** SQL line comment: from "--" up to the end of the line. */
    private static final Pattern LINE_COMMENT = Pattern.compile("--.*");
    /** SQL block comment (slash-star ... star-slash), possibly spanning several lines. */
    private static final Pattern BLOCK_COMMENT = Pattern.compile(
            "/\\*([^*]|[\\r\\n]|(\\*+([^*/]|[\\r\\n])))*\\*+/", Pattern.DOTALL | Pattern.MULTILINE);
    /** Any single line break: CRLF, CR or LF. */
    private static final Pattern LINE_BREAK = Pattern.compile("\\r\\n|[\\r\\n]");

    /**
     * Scans the classpath for resources matching an Ant-style pattern.
     * <p>
     * The scan root is the classpath location that contains the {@code top} package
     * (resolved with {@code getResource("top")}):
     * <ul>
     *   <li>If that location is inside a jar (or a war), every entry name of the archive is
     *   matched against {@code scanPath}, and the matching entry names are returned.</li>
     *   <li>Otherwise the classpath root directory ({@code getResource("")}) is scanned.
     *   If {@code scanPath} names an existing directory, its direct children are returned.
     *   Otherwise every file whose '/'-separated path relative to the root matches
     *   {@code scanPath} is returned. In this case the values are file-system paths, not
     *   entry names.</li>
     * </ul>
     * Results are cached per {@code scanPath}. Every call returns a fresh mutable copy, so
     * callers cannot corrupt the cache.
     *
     * @param scanPath Ant-style pattern, e.g. {@code mapper/*.xml} or {@code mapper/**.xml}; must not be null
     * @return matching resources; empty when no {@code top} resource is on the classpath
     * @throws MagicException if the archive cannot be read
     */
    public static List<String> scanResource(String scanPath) {
        Objects.requireNonNull(scanPath, "scanPath");
        List<String> cached = scanCache.get(scanPath);
        if (cached != null) {
            return new ArrayList<>(cached);
        }
        List<String> result = new ArrayList<>();
        URL url = MagicUtils.class.getClassLoader().getResource("top");
        if (url != null) {
            try {
                scanLocation(url, scanPath, result);
            } catch (IOException e) {
                throw new MagicException(e);
            }
        }
        scanCache.put(scanPath, result);
        return new ArrayList<>(result);
    }

    /** Scans the archive (or, failing that, the classpath root directory) that holds {@code url}. */
    private static void scanLocation(URL url, String scanPath, List<String> result) throws IOException {
        JarFile jarFile = null;
        boolean closeJarFile = false;
        URLConnection con = url.openConnection();
        if (con instanceof JarURLConnection) {
            JarURLConnection jarCon = (JarURLConnection) con;
            jarFile = jarCon.getJarFile();
            // A JarFile obtained through a caching connection is shared; closing it could break other readers.
            closeJarFile = !jarCon.getUseCaches();
        } else {
            // Not a jar: connection, so look for an archive separator in the URL manually.
            String urlFile = url.getFile();
            int separatorIndex = urlFile.indexOf("*/"); // tomcat
            if (separatorIndex == -1) {
                separatorIndex = urlFile.indexOf("!/"); // jar
            }
            if (separatorIndex != -1) {
                // The archive location is the part BEFORE the separator; the part after it is the entry path.
                jarFile = openJarFile(urlFile.substring(0, separatorIndex));
                closeJarFile = true;
            }
        }

        if (jarFile == null) {
            // No archive: probably an exploded classpath directory (IDEA / Eclipse development).
            URL root = MagicUtils.class.getClassLoader().getResource("");
            if (root != null) {
                scanDirectory(toFile(root), scanPath, result);
            }
            return;
        }
        try {
            scanJar(jarFile, scanPath, result);
        } finally {
            if (closeJarFile) {
                jarFile.close();
            }
        }
    }

    /** Adds the name of every entry in {@code jarFile} that matches {@code scanPath}. */
    private static void scanJar(JarFile jarFile, String scanPath, List<String> result) {
        Enumeration<JarEntry> entries = jarFile.entries();
        while (entries.hasMoreElements()) {
            String entryName = entries.nextElement().getName();
            if (antPathMatcher.match(scanPath, entryName)) {
                result.add(entryName);
            }
        }
    }

    /** Adds the paths of files under {@code root} that match {@code scanPath} (see {@link #scanResource}). */
    private static void scanDirectory(File root, String scanPath, List<String> result) {
        File[] children = new File(root, scanPath).listFiles();
        if (children != null) {
            // scanPath names a plain directory: keep the historical behaviour of listing its direct children.
            for (File child : children) {
                result.add(child.getPath());
            }
            return;
        }
        Deque<File> pending = new ArrayDeque<>();
        pending.push(root);
        while (!pending.isEmpty()) {
            File[] entries = pending.pop().listFiles();
            if (entries == null) {
                continue;
            }
            for (File entry : entries) {
                if (entry.isDirectory()) {
                    pending.push(entry);
                    continue;
                }
                // Match with '/' separators so patterns behave the same as for jar entry names.
                String relative = root.toPath().relativize(entry.toPath()).toString()
                        .replace(File.separatorChar, '/');
                if (antPathMatcher.match(scanPath, relative)) {
                    result.add(entry.getPath());
                }
            }
        }
    }

    /** Opens the archive given by the URL text that precedes the entry separator, e.g. {@code file:/app/x.jar}. */
    private static JarFile openJarFile(String jarFileUrl) throws IOException {
        if (jarFileUrl.startsWith(FILE_URL_PREFIX)) {
            try {
                // Decode escapes such as %20 so the path matches the real file name.
                return new JarFile(new URI(jarFileUrl.replace(" ", "%20")).getSchemeSpecificPart());
            } catch (URISyntaxException e) {
                // Not a valid URI: fall back to the raw text after "file:".
                return new JarFile(jarFileUrl.substring(FILE_URL_PREFIX.length()));
            }
        }
        return new JarFile(jarFileUrl);
    }

    /** Converts a file: URL to a {@link File}, decoding escapes; falls back to the raw URL path. */
    private static File toFile(URL url) {
        try {
            return new File(url.toURI());
        } catch (URISyntaxException | IllegalArgumentException e) {
            // Not a valid hierarchical file URI: use the raw path, as earlier versions did.
            return new File(url.getPath());
        }
    }

    /**
     * Returns all public methods of the class and its superclasses.
     *
     * @param clazz the class to inspect
     * @return public methods, as reported by hutool {@code ClassUtil.getPublicMethods}
     */
    public static Method[] getAllPublicMethod(Class<?> clazz) {
        return ClassUtil.getPublicMethods(clazz);
    }

    /**
     * Returns every declared field of the class and all of its superclasses. When a subclass
     * field has the same name as a superclass field, only the subclass field is kept.
     *
     * @param clazz the class to inspect
     * @return all fields: the class's own fields first, then those of each superclass in turn
     */
    public static Field[] getAllField(Class<?> clazz) {
        List<Field> fields = new ArrayList<>(Arrays.asList(clazz.getDeclaredFields()));
        Set<String> seenNames = new HashSet<>();
        for (Field field : fields) {
            seenNames.add(field.getName());
        }
        for (Class<?> type = clazz.getSuperclass(); type != null; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                // A field declared lower in the hierarchy shadows same-named fields further up.
                if (seenNames.add(field.getName())) {
                    fields.add(field);
                }
            }
        }
        return fields.toArray(new Field[0]);
    }

    /**
     * Executes a UTF-8 SQL script file in a single transaction and closes the connection afterwards.
     * <pre>
     * {@code
     * // Run an initialisation script if needed
     * Connection connection = sqlSession.getConnection();
     * String script = IoUtil.read(getClass().getClassLoader().getResourceAsStream("script/init-mysql.sql"), StandardCharsets.UTF_8);
     * MagicUtils.exeScript(script, connection);
     * }
     * </pre>
     *
     * @param sqlScriptFile SQL script file, read as UTF-8
     * @param connection    database connection; always closed by this method
     * @throws MagicException if any statement fails (the transaction is rolled back)
     */
    public static void exeScript(File sqlScriptFile, Connection connection) {
        exeScript(FileUtil.readString(sqlScriptFile, StandardCharsets.UTF_8), connection);
    }

    /**
     * Executes an SQL script in a single transaction and closes the connection afterwards.
     * <p>
     * Line and block comments are stripped, line breaks are turned into spaces, and the script
     * is split on {@code ;}. Blank fragments are skipped. Each remaining statement is run with
     * {@code executeUpdate}. Comment stripping and splitting use simple regular expressions that
     * do not understand quoted literals, so a {@code --} or {@code ;} inside a string literal
     * will break the script.
     * <p>
     * Auto-commit is switched off and is not restored, because the connection is closed anyway.
     * Re-enabling it after a failed rollback would commit the partial work.
     * <pre>
     * {@code
     * // Run an initialisation script if needed
     * Connection connection = sqlSession.getConnection();
     * String script = IoUtil.read(getClass().getClassLoader().getResourceAsStream("script/init-mysql.sql"), StandardCharsets.UTF_8);
     * MagicUtils.exeScript(script, connection);
     * }
     * </pre>
     *
     * @param sqlScript  SQL script content
     * @param connection database connection; always closed by this method
     * @throws MagicException if any statement fails (the transaction is rolled back; a rollback
     *                        failure is attached as a suppressed exception)
     */
    public static void exeScript(String sqlScript, Connection connection) {
        long start = System.currentTimeMillis();
        String tempSql = "";
        int number = 0;
        try {
            String script = LINE_COMMENT.matcher(sqlScript).replaceAll("");
            script = BLOCK_COMMENT.matcher(script).replaceAll("");
            // Replace (not delete) line breaks so tokens on adjacent lines stay separated,
            // e.g. "select a\nfrom t" must not become "select afrom t".
            script = LINE_BREAK.matcher(script).replaceAll(" ");
            String[] statements = script.split(";");

            log.info("开始执行sql脚本");
            connection.setAutoCommit(false);
            for (String sql : statements) {
                if (sql.trim().isEmpty()) {
                    continue;
                }
                tempSql = sql;
                try (PreparedStatement ps = connection.prepareStatement(sql)) {
                    ps.executeUpdate();
                }
                number++;
            }
            connection.commit();
        } catch (Exception e) {
            log.error("执行脚本出错，事务回滚。执行错误的sql是：\n{}", tempSql);
            try {
                connection.rollback();
            } catch (SQLException ex) {
                // Keep the original failure as the cause; the rollback failure is secondary.
                e.addSuppressed(ex);
            }
            throw new MagicException(e);
        } finally {
            IoUtil.close(connection);
            log.info("sql 执行完毕，共执行sql {} 个，耗时：{} ms", number, System.currentTimeMillis() - start);
        }
    }

    /**
     * Returns the JDBC URL reported by the connection's metadata.
     *
     * @param connection database connection
     * @param isClose    whether to close the connection afterwards (also on failure)
     * @return the value of {@code DatabaseMetaData.getURL()}
     * @throws MagicException if the metadata cannot be read
     */
    public static String getDatabaseURL(Connection connection, boolean isClose) {
        try {
            DatabaseMetaData metaData = connection.getMetaData();
            return metaData.getURL();
        } catch (Exception e) {
            throw new MagicException(e);
        } finally {
            if (isClose) {
                IoUtil.close(connection);
            }
        }
    }

    /**
     * Reads a field's value reflectively.
     * <p>
     * This method does not call {@code setAccessible}; the caller must do so if needed.
     *
     * @param field the field to read
     * @param val   the object to read from ({@code null} for static fields)
     * @return the field value, or {@code null} if access is denied ({@link IllegalAccessException})
     */
    public static Object getValue(Field field, Object val) {
        try {
            return field.get(val);
        } catch (IllegalAccessException e) {
            // Deliberate best-effort: inaccessible fields read as null.
            return null;
        }
    }
}
